### Config for coef plots
library(extrafont)
loadfonts()

## Colours used for significant coefficients in the coefficient plots:
## `inner.bar.col` colours the 95% interval bars, `outer.bar.col` the 83%
## interval bars and point markers. Non-significant terms are drawn grey.
## (Alternative purple palette: inner "#bcbddc", outer "#756bb1".)
inner.bar.col <- "#fcae91"
outer.bar.col <- "#a50f15"

## Line sizes of the 95% (inner) and 83% (outer) interval bars; these are
## overridden for the minimal style further down this file.
inner.bar.size <- 2
outer.bar.size <- 4

## Point shapes: element 1 is used by mycoefplot(), one per model in
## mycoefplot.mult().
my.pchs <- c(1, 0, 4, 2)

mod2df <- function(mod) {
    ## Tidy a fitted model into a data.frame of coefficients for plotting.
    ##
    ## Args:
    ##   mod: a fitted model. "stanreg" fits use posterior quantiles;
    ##        "polr" fits use the `Value` column (threshold rows dropped);
    ##        anything else (bayesglm, glmerMod, ...) is assumed to have a
    ##        summary()$coefficients table with `Estimate` and `Std. Error`.
    ## Returns:
    ##   The coefficient table (row names = term names) plus columns
    ##   `mean`, `lo`/`hi` (95% interval), `llo`/`hhi` (83% interval) and
    ##   `isSignificant` (TRUE when the 95% interval excludes zero).

    ## stanreg must be tested first: e.g. stan_polr fits also inherit "polr".
    if (inherits(mod, "stanreg")) {
        ## Quantiles bounding the 95% (2.5-97.5) and 83% (8.5-91.5) intervals
        the.coefs <- as.data.frame(summary(mod,
                                           pars = "beta",
                                           probs = c(0.025, 0.085,
                                                     1 - 0.085, 0.975)))
        the.coefs$hi <- the.coefs[, "97.5%"]
        the.coefs$lo <- the.coefs[, "2.5%"]
        the.coefs$hhi <- the.coefs[, "91.5%"]
        the.coefs$llo <- the.coefs[, "8.5%"]
    } else {
        the.coefs <- as.data.frame(summary(mod)$coefficients)
        if (inherits(mod, "polr")) {
            ## Drop the cut-point rows, whose names contain "|" (e.g. "1|2")
            keep <- !grepl("|", rownames(the.coefs), fixed = TRUE)
            the.coefs <- the.coefs[keep, , drop = FALSE]
            the.coefs$mean <- the.coefs$Value
        } else {
            the.coefs$mean <- the.coefs$Estimate
        }
        ## Normal approximation: 1.96 ~ qnorm(0.975) for 95% intervals,
        ## 1.372 ~ qnorm(0.915) for 83% intervals
        se <- the.coefs$`Std. Error`
        the.coefs$hi <- the.coefs$mean + 1.96 * se
        the.coefs$lo <- the.coefs$mean - 1.96 * se
        the.coefs$hhi <- the.coefs$mean + 1.372 * se
        the.coefs$llo <- the.coefs$mean - 1.372 * se
    }

    ## Same rule as isSig(), vectorised: the 95% interval excludes zero
    ## (NA when a missing bound leaves it undetermined)
    the.coefs$isSignificant <- the.coefs$lo > 0 | the.coefs$hi < 0
    the.coefs
}

## Minimal-style settings: thinner interval bars, same red palette.
inner.bar.size <- 0.5
outer.bar.size <- 1
inner.bar.col <- "#fcae91"
outer.bar.col <- "#a50f15"

mycoefplot <- function(mod,
                       variable.labels = NULL,
                       dropvars = NULL,
                       ylim = NULL,
                       ylab = "",
                       sec.ylab = NULL) {
    ## Horizontal coefficient plot (95% and 83% intervals) for one model.
    ##
    ## Args:
    ##   mod: fitted model accepted by mod2df().
    ##   variable.labels: named vector; names are term names, values the
    ##     labels to display. Unmatched terms keep their raw names.
    ##   dropvars: terms to omit, by name (character) or row position.
    ##   ylim: coefficient-axis limits. If NULL, the limits span all finite
    ##     interval ends and zero (no limits are set with a secondary axis).
    ##   ylab: coefficient-axis title.
    ##   sec.ylab: if non-NULL, adds a secondary axis showing exp(coef).
    ## Returns:
    ##   A ggplot object.
    require(ggplot2)
    require(car)

    the.coefs <- mod2df(mod)

    ## Significant terms use the global palette; the rest are greyed out
    the.coefs$inner.bar.col <- ifelse(the.coefs$isSignificant,
                                      inner.bar.col,
                                      "lightgrey")
    the.coefs$outer.bar.col <- ifelse(the.coefs$isSignificant,
                                      outer.bar.col,
                                      "darkgrey")

    ## Drop unwanted terms (character dropvars match mod2df() row names)
    if (!is.null(dropvars)) {
        if (is.character(dropvars)) {
            the.coefs <- the.coefs[setdiff(rownames(the.coefs), dropvars), ]
        } else {
            the.coefs <- the.coefs[-dropvars, ]
        }
    }

    ## Display labels, matched against the original term names
    the.coefs$variable <- rownames(the.coefs)
    if (!is.null(variable.labels)) {
        hit <- match(the.coefs$variable, names(variable.labels))
        found <- !is.na(hit)
        the.coefs$variable[found] <- variable.labels[hit[found]]
    }

    ## Axis limits: explicit ylim wins; otherwise span finite CIs and zero
    if (is.null(ylim)) {
        finite.lo <- the.coefs$lo[is.finite(the.coefs$lo)]
        finite.hi <- the.coefs$hi[is.finite(the.coefs$hi)]
        my.ylims <- c(min(finite.lo, 0), max(finite.hi, 0))
    } else {
        my.ylims <- ylim
    }

    ## Order terms by point estimate; unique() tolerates shared labels
    the.coefs <- the.coefs[order(the.coefs$mean), ]
    the.coefs$variable <- factor(the.coefs$variable,
                                 levels = unique(the.coefs$variable),
                                 ordered = TRUE)

    retval <- ggplot(the.coefs,
                     aes(x = variable, y = mean,
                         ymin = lo, ymax = hi)) +
        scale_x_discrete("")

    if (is.null(sec.ylab)) {
        retval <- retval +
            scale_y_continuous(ylab, limits = my.ylims)
    } else {
        ## The transformation is passed positionally: the argument is
        ## `trans` before ggplot2 3.5.0 and `transform` afterwards.
        retval <- retval +
            scale_y_continuous(ylab,
                               limits = ylim,
                               sec.axis = sec_axis(~exp(.),
                                                   name = sec.ylab))
    }

    ## NOTE(review): newer ggplot2 prefers `linewidth` over `size` for
    ## lines; `size` is kept for compatibility with older versions.
    retval <- retval +
        geom_linerange(aes(colour = inner.bar.col),
                       size = inner.bar.size,
                       alpha = 0.8) +
        geom_linerange(aes(ymin = llo, ymax = hhi, colour = outer.bar.col),
                       size = outer.bar.size,
                       alpha = 0.8) +
        ## A single reference line at zero (not one per coefficient row)
        geom_hline(yintercept = 0, lty = 1) +
        geom_point(aes(colour = outer.bar.col),
                   size = (inner.bar.size + outer.bar.size) * 2.5,
                   shape = my.pchs[1],
                   stroke = 1.5) +
        scale_colour_identity() +
        scale_fill_identity() +
        theme_uksc() +
        coord_flip()

    retval
}

mycoefplot.mult <- function(modlist,
                            variable.labels = NULL,
                            model.names = NULL,
                            dropvars = NULL,
                            ylim = NULL,
                            ylab = "",
                            sec.ylab = NULL) {
    ## Dodged coefficient plot comparing several models term by term.
    ##
    ## Args:
    ##   modlist: list of fitted models accepted by mod2df().
    ##   variable.labels: named vector; names are term names, values the
    ##     labels to display.
    ##   model.names: legend/label text per model; default "(1)", "(2)", ...
    ##   dropvars: terms to omit. Character values drop that term from every
    ##     model; numeric values are row positions in the combined table.
    ##   ylim: coefficient-axis limits; if NULL they span finite CIs and 0.
    ##   ylab: coefficient-axis title.
    ##   sec.ylab: if non-NULL, adds a secondary axis showing exp(coef).
    ## Returns:
    ##   A ggplot object.
    require(ggplot2)
    require(car)
    dodge.width <- 0.75
    dodge <- position_dodge(width = dodge.width)

    if (is.null(model.names)) {
        model.names <- paste0("(", seq_along(modlist), ")")
    }

    ## Record each model's own term names in `term` before rbind(), which
    ## rewrites duplicated row names (e.g. "x", "x1", "x2", ...). Using
    ## `term` avoids stripping digits that are part of a real term name.
    modholder <- lapply(modlist, mod2df)
    for (i in seq_along(modholder)) {
        modholder[[i]]$term <- rownames(modholder[[i]])
        modholder[[i]]$model.name <- model.names[i]
    }
    the.coefs <- do.call("rbind", modholder)

    ## Significant terms use the global palette; the rest are greyed out
    the.coefs$inner.bar.col <- ifelse(the.coefs$isSignificant,
                                      inner.bar.col,
                                      "lightgrey")
    the.coefs$outer.bar.col <- ifelse(the.coefs$isSignificant,
                                      outer.bar.col,
                                      "darkgrey")

    ## Drop unwanted terms (by name from all models, or by row position)
    if (!is.null(dropvars)) {
        if (is.character(dropvars)) {
            the.coefs <- the.coefs[!(the.coefs$term %in% dropvars), ]
        } else {
            the.coefs <- the.coefs[-dropvars, ]
        }
    }

    ## Display labels, matched against the original term names
    the.coefs$variable <- the.coefs$term
    if (!is.null(variable.labels)) {
        hit <- match(the.coefs$variable, names(variable.labels))
        found <- !is.na(hit)
        the.coefs$variable[found] <- variable.labels[hit[found]]
    }

    ## Axis limits: explicit ylim wins; otherwise span finite CIs and zero
    if (is.null(ylim)) {
        finite.lo <- the.coefs$lo[is.finite(the.coefs$lo)]
        finite.hi <- the.coefs$hi[is.finite(the.coefs$hi)]
        my.ylims <- c(min(finite.lo, 0), max(finite.hi, 0))
    } else {
        my.ylims <- ylim
    }

    ## Order terms by point estimate (first appearance sets factor order)
    the.coefs <- the.coefs[order(the.coefs$mean), ]
    the.coefs$variable <- factor(the.coefs$variable,
                                 levels = unique(the.coefs$variable),
                                 ordered = TRUE)

    ## Model-name labels are drawn for every row but coloured white (i.e.
    ## hidden) except for the last factor level, the top row after flipping
    the.coefs$name.color <- ifelse(the.coefs$variable == rev(levels(the.coefs$variable))[1],
                                   "black",
                                   "white")

    retval <- ggplot(data = the.coefs,
                     aes(x = variable, y = mean, ymin = lo,
                         group = model.name,
                         ymax = hi)) +
        scale_x_discrete("")

    if (is.null(sec.ylab)) {
        retval <- retval +
            scale_y_continuous(ylab, limits = my.ylims)
    } else {
        retval <- retval +
            scale_y_continuous(ylab,
                               limits = my.ylims,
                               sec.axis = sec_axis(~exp(.),
                                                   breaks = scales::log_breaks(),
                                                   name = sec.ylab))
    }

    ## NOTE(review): newer ggplot2 prefers `linewidth` over `size` for
    ## lines; `size` is kept for compatibility with older versions.
    retval <- retval +
        geom_linerange(aes(colour = inner.bar.col),
                       position = dodge,
                       size = inner.bar.size,
                       alpha = 0.8) +
        geom_linerange(aes(ymin = llo, ymax = hhi,
                           colour = outer.bar.col),
                       position = dodge,
                       size = outer.bar.size,
                       alpha = 0.8) +
        ## A plain reference line at zero; mapping x = 0 onto the discrete
        ## x scale was an unknown aesthetic for geom_hline
        geom_hline(yintercept = 0, lty = 1) +
        geom_point(aes(colour = outer.bar.col, shape = model.name),
                   position = dodge,
                   size = (inner.bar.size + outer.bar.size) * 2.5,
                   stroke = 1.5) +
        geom_text(data = the.coefs,
                  aes(x = variable, y = hi,
                      label = model.name,
                      color = name.color),
                  position = dodge,
                  hjust = 0,
                  size = 2.5) +
        scale_shape_manual("Model", values = my.pchs) +
        scale_colour_identity() +
        scale_fill_identity() +
        theme_uksc() +
        coord_flip() +
        theme(legend.position = "bottom",
              legend.direction = "horizontal")

    retval
}

num2judge <- function(x) {
    ## Map numeric judge codes to judge surnames via car::recode().
    ## Codes 1-20 and 101-109 are listed below; other values are left
    ## unchanged by recode(). For numeric input the result is character.
    ## Uses the exported car::recode rather than reaching in with `:::`.
    require(car)
    car::recode(x,
                "'1'='Phillips';
'2'='Hope';
'3'='Saville';
'4'='Rodger';
'5'='Walker';
'6'='Hale';
'7'='Brown';
'8'='Mance';
'9'='Collins';
'10'='Kerr';
'11'='Clarke';
'12'='Dyson';
'13'='Neuberger';
'14'='Wilson';
'15'='Sumption';
'16'='Reed';
'17'='Carnwath';
'18'='Hughes';
'19'='Toulson';
'20'='Hodge';
'101'='Scott';
'102'='Matthew Clarke';
'103'='Carloway';
'104'='Judge';
'105'='Hamilton';
'106'='Thomas';
'107'='Kirk';
'108'='Gill';
'109'='Girvan'")
}

ct2area <- function(x) {
    ## Map numeric court codes to broad areas of law via car::recode().
    ## Codes not listed in the spec (the `else` clause) become NA.
    ## NOTE(review): codes 1, 11, 21 and 28 are not listed and so map to
    ## NA — confirm that this is intended.
    ## Uses the exported car::recode rather than reaching in with `:::`.
    require(car)
    car::recode(x,
                "13:14='Scots';
		15:18='N Ireland';
		c(2, 19, 22) ='Criminal';
		c(4, 8) = 'Chancery';
		c(6, 12, 39) = 'Family';
		c(3, 20, 23, 24, 25,26, 29, 30, 31, 33, 34, 35, 38, 40, 41, 42, 43) = 'Public';
		c(5, 7, 9, 10, 27, 32, 36, 37, 44) = 'Civil';
		else = NA")
}

pcp <- function(mod) {
    ## Proportion correctly predicted: classify each fitted value as
    ## "event" when it exceeds 0.5 and compare with the observed response
    ## stored in mod@resp$y (the slot layout of lme4 merMod fits).
    ## Pairs where the comparison is NA are ignored.
    predicted.event <- fitted(mod) > 0.5
    observed <- mod@resp$y
    matches <- predicted.event == observed
    mean(matches, na.rm = TRUE)
}

simple_pre <- function(mod) {
    ## Proportional reduction in error relative to always predicting the
    ## most frequent observed response category.
    ## Returns list(pre = PRE, pcp = proportion correctly predicted).
    ## NOTE(review): PRE is not finite when every response is identical
    ## (the baseline then predicts perfectly).
    model.pcp <- pcp(mod)
    observed <- mod@resp$y
    counts <- sort(table(observed), decreasing = TRUE)
    modal.level <- names(counts)[1]
    baseline.pcp <- mean(as.character(observed) == modal.level)
    list(pre = (model.pcp - baseline.pcp) / (1 - baseline.pcp),
         pcp = model.pcp)
}

pre_glmer <- function(mod1, mod2 = NULL, sim = FALSE, R = 2500) {
    ## Proportional reduction in error (PRE) and expected PRE of a binary
    ## response model `mod1` relative to a baseline `mod2`.
    ##
    ## Args:
    ##   mod1: fitted binary-response model; an lme4 glmerMod (response in
    ##     @resp$y) or a glm (response in $y).
    ##   mod2: baseline model of the same class as mod1. If NULL, an
    ##     intercept-only glm is fitted on mod1's model frame.
    ##   sim: if TRUE, also simulate PRE/ePRE by drawing R coefficient
    ##     vectors from each model's sampling distribution.
    ##   R: number of simulation draws.
    ## Returns:
    ##   A list of class "pre" with pre, epre, m1form, m2form, pcp, pmc,
    ##   epmc, epcp and, when sim = TRUE, pre.sim and epre.sim.

    ## Observed response: S4 (lme4) fits store it in @resp$y, glm in $y.
    get.y <- function(m) {
        if (isS4(m)) m@resp$y else m$y
    }
    ## Fixed-effect coefficients: fixef() for S4 (lme4) fits, else coef().
    get.beta <- function(m) {
        if (isS4(m)) fixef(m) else coef(m)
    }
    ## Proportion of outcomes matched when probabilities are cut at 0.5.
    prop.correct <- function(p, y) {
        mean(as.numeric(p >= 0.5) == y)
    }
    ## Expected proportion correct: summed probability given to the
    ## observed outcome, divided by the number of observations.
    expected.correct <- function(p, y) {
        (sum(p[which(y == 1)]) + sum(1 - p[which(y == 0)])) / length(p)
    }

    ## as.character() drops the "package" attribute lme4 puts on class()
    if (!is.null(mod2) &&
        !identical(as.character(class(mod1)), as.character(class(mod2)))) {
        stop("Model 2 must be either NULL or of the same class as Model 1\n")
    }
    if (inherits(mod1, c("glmerMod", "glm")) &&
        !(family(mod1)$link %in% c("logit", "probit", "cloglog", "cauchit"))) {
        stop("PRE only calculated for models with logit, probit, cloglog or cauchit links\n")
    }
    if (is.null(mod2)) {
        mod2 <- glm(update(formula(mod1), ". ~ 1"),
                    data = model.frame(mod1),
                    family = family(mod1))
    }

    y1 <- get.y(mod1)
    y2 <- get.y(mod2)
    pred.prob1 <- predict(mod1, type = "response")
    pred.prob2 <- predict(mod2, type = "response")

    ## Named pcp.obs so the sibling pcp() function is not masked
    pcp.obs <- prop.correct(pred.prob1, y1)
    pmc <- prop.correct(pred.prob2, y2)
    pre <- (pcp.obs - pmc) / (1 - pmc)

    epcp <- expected.correct(pred.prob1, y1)
    epmc <- expected.correct(pred.prob2, y2)
    epre <- (epcp - epmc) / (1 - epmc)

    if (sim) {
        ## as.matrix() turns lme4's Matrix-class vcov into a base matrix
        b1.sim <- MASS::mvrnorm(R, get.beta(mod1), as.matrix(vcov(mod1)))
        b2.sim <- MASS::mvrnorm(R, get.beta(mod2), as.matrix(vcov(mod2)))
        ## One column of predicted probabilities per simulated draw
        mod1.probs <- family(mod1)$linkinv(model.matrix(mod1) %*% t(b1.sim))
        mod2.probs <- family(mod2)$linkinv(model.matrix(mod2) %*% t(b2.sim))

        pcps <- apply(mod1.probs, 2, prop.correct, y = y1)
        pmcs <- apply(mod2.probs, 2, prop.correct, y = y2)
        pre.sim <- (pcps - pmcs) / (1 - pmcs)

        epcps <- apply(mod1.probs, 2, expected.correct, y = y1)
        epmcs <- apply(mod2.probs, 2, expected.correct, y = y2)
        epre.sim <- (epcps - epmcs) / (1 - epmcs)
    }

    form1 <- formula(mod1)
    form2 <- formula(mod2)
    ret <- list(pre = pre,
                epre = epre,
                m1form = paste(form1[2], form1[1], form1[3], sep = " "),
                m2form = paste(form2[2], form2[1], form2[3], sep = " "),
                pcp = pcp.obs,
                pmc = pmc,
                epmc = epmc,
                epcp = epcp)
    if (sim) {
        ret$pre.sim <- pre.sim
        ret$epre.sim <- epre.sim
    }
    class(ret) <- "pre"
    ret
}


theme_uksc <- function(family = "Fira Sans Condensed Light") {
    ## House ggplot2 theme: theme_minimal() with every text element set in
    ## `family`. The default font must be available to the graphics device
    ## (this file loads fonts via extrafont); pass another family, e.g.
    ## "sans", on machines where it is not installed.
    theme_minimal() +
        theme(text = element_text(family = family))
}

isSig <- function(x) {
    ## x is an interval c(lower, upper). Returns TRUE when the interval
    ## lies strictly on one side of zero, FALSE when it contains or touches
    ## zero, and NA when a missing bound leaves the answer undetermined.
    covers.zero <- x[1] <= 0 && x[2] >= 0
    !covers.zero
}


trim <- function(x) {
    ## Remove leading and trailing spaces, then collapse each internal run
    ## of spaces to a single space. Only the space character is affected.
    edges.removed <- Reduce(function(s, pattern) gsub(pattern, "", s),
                            c("^ +", " +$"),
                            x)
    gsub(" +", " ", edges.removed)
}
